{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "We have trained a well-trained checkpoint through the `ocr-sft.ipynb` tutorial, and here we use `PtEngine` to do the inference on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import some libraries\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "from swift.llm import (\n",
    "    InferEngine, InferRequest, PtEngine, RequestConfig, get_template, load_dataset, load_image\n",
    ")\n",
    "from swift.utils import get_model_parameter_info, get_logger, seed_everything\n",
    "logger = get_logger()\n",
    "seed_everything(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for inference\n",
    "last_model_checkpoint = 'output/checkpoint-xxx'\n",
    "\n",
    "# model\n",
    "model_id_or_path = 'Qwen/Qwen2-VL-2B-Instruct'  # model_id or model_path\n",
    "system = None\n",
    "infer_backend = 'pt'\n",
    "\n",
    "# dataset\n",
    "dataset = ['AI-ModelScope/LaTeX_OCR#20000']\n",
    "data_seed = 42\n",
    "split_dataset_ratio = 0.01\n",
    "num_proc = 4\n",
    "strict = False\n",
    "\n",
    "# generation_config\n",
    "max_new_tokens = 512\n",
    "temperature = 0\n",
    "stream = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get model and template, and load LoRA weights.\n",
    "engine = PtEngine(model_id_or_path, adapters=[last_model_checkpoint])\n",
    "template = get_template(engine.model_meta.template, engine.processor, default_system=system)\n",
    "# The default mode of the template is 'pt', so there is no need to make any changes.\n",
    "# template.set_mode('pt')\n",
    "\n",
    "model_parameter_info = get_model_parameter_info(engine.model)\n",
    "logger.info(f'model_parameter_info: {model_parameter_info}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Due to the data_seed setting, the validation set here is the same as the validation set used during training.\n",
    "_, val_dataset = load_dataset(dataset, split_dataset_ratio=split_dataset_ratio, num_proc=num_proc,\n",
    "                              strict=strict, seed=data_seed)\n",
    "val_dataset = val_dataset.select(range(10))  # Take the first 10 items"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Streaming inference and save images from the validation set.\n",
    "# The batch processing code can be found here: https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_mllm.py\n",
    "def infer_stream(engine: InferEngine, infer_request: InferRequest):\n",
    "    request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature, stream=True)\n",
    "    gen_list = engine.infer([infer_request], request_config)\n",
    "    query = infer_request.messages[0]['content']\n",
    "    print(f'query: {query}\\nresponse: ', end='')\n",
    "    for resp in gen_list[0]:\n",
    "        if resp is None:\n",
    "            continue\n",
    "        print(resp.choices[0].delta.content, end='', flush=True)\n",
    "    print()\n",
    "\n",
    "from IPython.display import display\n",
    "os.makedirs('images', exist_ok=True)\n",
    "for i, data in enumerate(val_dataset):\n",
    "    image = data['images'][0]\n",
    "    image = load_image(image['bytes'] or image['path'])\n",
    "    image.save(f'images/{i}.png')\n",
    "    display(image)\n",
    "    infer_stream(engine, InferRequest(**data))\n",
    "    print('-' * 50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test_py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
